#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File              : gift_init.py
# Author            : Yiting Liu, Yibo Lin <yibolin@pku.edu.cn>
# Date              : 12.16.2024
# Last Modified Date: 12.17.2024
# Last Modified By  : Yibo Lin <yibolin@pku.edu.cn>

import torch
from torch.autograd import Function
from torch import nn
import numpy as np
import scipy 
import dreamplace.ops.gift_init.utils_gpu.util as util
import dreamplace.ops.gift_init.utils_gpu.mix_frequency_filter as mix_frequency_filter

import dreamplace.ops.gift_init.gift_init_cpp as gift_init_cpp

import logging
logger = logging.getLogger(__name__)

import time
import pdb

class GiFtInit(nn.Module):
    """
    @brief Compute initial position using GiFt technique published at ICCAD 2024.
    Yiting Liu et al., The Power of Graph Signal Processing for Chip Placement Acceleration, ICCAD 2024.

    At construction time the netlist is converted once into a sparse node
    adjacency matrix by the C++ extension (logged as the clique model).
    Each forward call does three things:
    - generates initial locations for the movable nodes via
      util.generate_initial_locations, given the fixed node locations and scale;
    - runs three filter passes through mix_frequency_filter.GiFt_GPU;
    - returns the blend 0.2 * low + 0.7 * m + 0.1 * h of the three results.
    """
    def __init__(self,
                 flat_netpin,
                 netpin_start,
                 pin2node_map,
                 net_weights,
                 net_mask,
                 num_nodes,
                 num_movable_nodes,
                 scale=0.7):
        """
        @brief initialization
        @param flat_netpin flat netpin map, length of #pins
        @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
        @param pin2node_map pin2node map
        @param net_weights weight of nets; float32 weights give a float32 adjacency matrix, any other dtype gives float64
        @param net_mask whether to compute wirelength, 1 means to compute, 0 means to ignore; users should guarantee invalid nets are filtered out
        @param num_nodes number of nodes covered by the (num_nodes x num_nodes) adjacency matrix;
            nodes [num_movable_nodes, num_nodes) are treated as fixed in forward()
        @param num_movable_nodes number of movable nodes; nodes [0, num_movable_nodes) get generated initial locations
        @param scale the distribution range of the initial locations generated by GiFt (customizable)
        """
        # `import scipy` alone does not load the sparse submodule on SciPy
        # versions without lazy submodule loading; import it explicitly.
        import scipy.sparse

        super().__init__()

        self.num_nodes = num_nodes
        self.num_movable_nodes = num_movable_nodes
        self.num_fixed_nodes = num_nodes - num_movable_nodes
        self.scale = scale

        logger.info('Construct adjacency matrix using clique model')

        # All inputs are moved to CPU before calling the C++ op.
        # Its first three outputs are used as COO (values, row indices, column indices).
        ret = gift_init_cpp.adj_matrix_forward(
            flat_netpin.cpu(), netpin_start.cpu(), pin2node_map.cpu(),
            net_weights.cpu(), net_mask.cpu(), num_nodes)
        data, row, col = ret[0], ret[1], ret[2]
        dtype = np.float32 if net_weights.dtype == torch.float32 else np.float64
        # Stored as a plain scipy matrix (not a buffer/parameter), so
        # nn.Module.to(device) leaves it untouched; forward() hands it to the filter.
        self.adj_mat = scipy.sparse.coo_matrix(
            (data.numpy(), (row.numpy(), col.numpy())),
            dtype=dtype, shape=(num_nodes, num_nodes))

        logger.info('Done matrix construction')

    def forward(self, pos):
        """
        @brief Compute GiFt initial locations (no gradients are tracked).
        @param pos location tensor with an even number of elements. It is viewed as
            2 rows (assumed: all x coordinates followed by all y coordinates — confirm
            against the caller) and transposed to one row per node. Rows
            [num_movable_nodes, num_movable_nodes + num_fixed_nodes) are used as the
            fixed node locations. Later rows, if any, are ignored.
        @return the blended filtered locations transposed with .t(). The working
            tensors are float32 on pos.device. Presumably the shape is (2, num_nodes);
            NOTE(review): confirm against mix_frequency_filter.GiFt_GPU.get_cell_position.
        """
        with torch.no_grad():
            # --------------------Generate the initial locations of movable cells------------------------#
            pos_t = pos.view([2, -1]).t().cpu().numpy()
            fixed_cell_location = pos_t[self.num_movable_nodes:self.num_movable_nodes + self.num_fixed_nodes]
            random_initial = util.generate_initial_locations(fixed_cell_location, self.num_movable_nodes, self.scale)
            # Movable rows first, then the fixed rows, matching the node index order.
            random_initial = np.concatenate((random_initial, fixed_cell_location), 0)
            random_initial = torch.from_numpy(random_initial).float().to(pos.device)

            # ----------- low-pass filter ------------------------#
            # perf_counter is monotonic, so the reported duration is immune to wall-clock changes.
            start = time.perf_counter()
            low_pass_filter = mix_frequency_filter.GiFt_GPU(self.adj_mat, pos.device)
            low_pass_filter.train(4)
            location_low = low_pass_filter.get_cell_position(4, random_initial)
            logger.info('finish initial placement')

            # ----------- m-pass filter ------------------------#
            # NOTE(review): train(4) repeats the previous call's argument; kept as-is
            # because train()'s internal state handling is not visible from here.
            low_pass_filter.train(4)
            location_m = low_pass_filter.get_cell_position(2, random_initial)
            logger.info('finish m-pass filter!')

            # ----------- h-pass filter ------------------------#
            low_pass_filter.train(2)
            location_h = low_pass_filter.get_cell_position(2, random_initial)
            logger.info('finish h-pass filter!')
            end = time.perf_counter()
            logger.info('total time %g sec', end - start)

            # ------------ final location  ------------------------#
            location = 0.2 * location_low + 0.7 * location_m + 0.1 * location_h

            # need to transpose back
            return location.t()

